import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter


class NormedLinear(nn.Module):
    """Bias-free linear layer whose outputs are cosine similarities.

    Both the input rows and the weight columns are L2-normalised before the
    matrix product, so every output entry lies in [-1, 1].

    Args:
        in_features: size of each input row.
        out_features: number of output columns (one weight column each).
    """

    def __init__(self, in_features, out_features):
        super(NormedLinear, self).__init__()
        init = torch.empty(in_features, out_features)
        # Uniform draw, then clamp every column (dim=1 sub-tensors of renorm)
        # to norm 1e-5 and scale back up by 1e5: columns start near unit norm.
        init.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.weight = Parameter(init)

    def forward(self, x):
        """Return the (batch, out_features) matrix of cosine similarities.

        ``x`` must be 2-D (``mm`` is used), shaped (batch, in_features).
        """
        unit_rows = F.normalize(x, dim=1)
        unit_cols = F.normalize(self.weight, dim=0)
        return unit_rows.mm(unit_cols)

# network module
class Cnn3D(nn.Module):
    """Two-branch 3D CNN fusing a DWI volume with a T2-axial volume.

    Each input runs through its own convolutional encoder. The encoders share
    an architecture but have separate weights. Each encoder output is averaged
    over its three spatial axes, and the two 32-wide vectors are concatenated
    into a 64-wide summary that feeds one of three heads.

    Inputs are 5-D ``(N, C, D, H, W)`` tensors. Each spatial extent must
    survive the encoder's convolutions and poolings: D >= 6 and H, W >= 68
    is the minimum.

    Args:
        use_norm: if truthy, use a ``NormedLinear(64, 2)`` head.
        hot_enc: head selector when ``use_norm`` is falsy. ``0`` selects one
            output unit and ``1`` selects two output units.
        n_in_dwi: channel count of the DWI input.
        n_in_st: channel count of the T2-axial input.
    """

    # Channels produced by each encoder (its final Conv3d).
    _BRANCH_CHANNELS = 32

    def __init__(self, use_norm, hot_enc, n_in_dwi, n_in_st=1):
        super(Cnn3D, self).__init__()
        # Same construction order as before, so parameter init and
        # state_dict keys ("dwi_cnn.0.weight", ...) are unchanged.
        self.dwi_cnn = self._make_encoder(n_in_dwi)
        self.t2ax_cnn = self._make_encoder(n_in_st)
        n_features = 2 * self._BRANCH_CHANNELS
        self.dnn1 = nn.Sequential(
            nn.Linear(n_features, 1),
        )
        self.dnn2 = nn.Sequential(
            nn.Linear(n_features, 2),
        )
        self.hot_enc = hot_enc
        self.use_norm = use_norm
        if self.use_norm:
            self.linear = NormedLinear(n_features, 2)

    @staticmethod
    def _make_encoder(in_channels):
        """Build one encoder branch mapping ``in_channels`` to 32 channels."""
        # After the first (isotropic) pooling, the convs pad only dim 2 and
        # the poolings only shrink dims 3-4. Dim 2 therefore keeps its size
        # while dims 3-4 keep shrinking.
        return nn.Sequential(
            nn.Conv3d(in_channels, 8, 3),
            nn.ReLU(),
            nn.Conv3d(8, 16, 3),
            nn.ReLU(),
            nn.MaxPool3d(2, stride=2),
            nn.Conv3d(16, 16, 3, padding=[1, 0, 0]),
            nn.ReLU(),
            nn.Conv3d(16, 32, 3, padding=[1, 0, 0]),
            nn.ReLU(),
            nn.MaxPool3d([1, 2, 2], stride=[1, 2, 2]),
            nn.Conv3d(32, 64, 3, padding=[1, 0, 0]),
            nn.ReLU(),
            nn.Conv3d(64, 32, 3, padding=[1, 0, 0]),
            nn.ReLU(),
            nn.MaxPool3d([1, 2, 2], stride=[1, 2, 2]),
            nn.Conv3d(32, 64, 3, padding=[1, 0, 0]),
            nn.ReLU(),
            nn.Conv3d(64, 32, 3, padding=[1, 0, 0]),
        )

    def forward(self, dwi, t2ax):
        """Classify a (DWI, T2-axial) pair.

        Returns:
            If ``use_norm`` is truthy, a tuple ``(softmax_probs, logits)``,
            each shaped (N, 2). Otherwise the raw logits only: (N, 1) for
            ``hot_enc == 0`` or (N, 2) for ``hot_enc == 1``.

        Raises:
            ValueError: if ``use_norm`` is falsy and ``hot_enc`` is not 0 or 1.
                The previous version silently returned ``None`` here.
        """
        summary_dwi = self.dwi_cnn(dwi).mean([2, 3, 4])
        summary_t2ax = self.t2ax_cnn(t2ax).mean([2, 3, 4])
        summary = torch.cat((summary_dwi, summary_t2ax), dim=1)
        # Truthy test, matching the condition under which __init__ builds
        # self.linear.
        if self.use_norm:
            out = self.linear(summary)
            return F.softmax(out, dim=1), out
        if self.hot_enc == 0:
            return self.dnn1(summary)
        if self.hot_enc == 1:
            return self.dnn2(summary)
        raise ValueError(
            f"unsupported hot_enc={self.hot_enc!r}; expected 0 or 1"
        )



class Cnn3D_single_path(nn.Module):
    """Single-branch 3D CNN classifier.

    One convolutional encoder maps ``img`` to 32 channels. These are averaged
    over the three spatial axes and fed to one of three heads. The heads
    previously expected 64 input features although the encoder produces 32,
    so every ``forward`` call failed with a shape mismatch. They are now
    sized from ``_FEATURE_CHANNELS``.

    Inputs are 5-D ``(N, C, D, H, W)`` tensors. Each spatial extent must
    survive the encoder's convolutions and poolings: D >= 6 and H, W >= 68
    is the minimum.

    Args:
        use_norm: if truthy, use a ``NormedLinear`` head with 2 outputs.
        hot_enc: head selector when ``use_norm`` is falsy. ``0`` selects one
            output unit and ``1`` selects two output units.
        n_in_dwi: not used by this block. The encoder's input channel count
            comes from ``n_in_st``. NOTE(review): confirm this is intended.
        n_in_st: channel count of ``img``.
    """

    # Channels produced by the encoder's final Conv3d. This is the width of
    # the pooled feature vector every head consumes.
    _FEATURE_CHANNELS = 32

    def __init__(self, use_norm, hot_enc, n_in_dwi, n_in_st=1):
        super(Cnn3D_single_path, self).__init__()

        self.processing = nn.Sequential(
            nn.Conv3d(n_in_st, 8, 3),
            nn.ReLU(),
            nn.Conv3d(8, 16, 3),
            nn.ReLU(),
            nn.MaxPool3d(2, stride=2),
            nn.Conv3d(16, 16, 3, padding=[1, 0, 0]),
            nn.ReLU(),
            nn.Conv3d(16, 32, 3, padding=[1, 0, 0]),
            nn.ReLU(),
            nn.MaxPool3d([1, 2, 2], stride=[1, 2, 2]),
            nn.Conv3d(32, 64, 3, padding=[1, 0, 0]),
            nn.ReLU(),
            nn.Conv3d(64, 32, 3, padding=[1, 0, 0]),
            nn.ReLU(),
            nn.MaxPool3d([1, 2, 2], stride=[1, 2, 2]),
            nn.Conv3d(32, 64, 3, padding=[1, 0, 0]),
            nn.ReLU(),
            nn.Conv3d(64, self._FEATURE_CHANNELS, 3, padding=[1, 0, 0]),
        )
        self.dnn1 = nn.Sequential(
            nn.Linear(self._FEATURE_CHANNELS, 1),
        )
        self.dnn2 = nn.Sequential(
            nn.Linear(self._FEATURE_CHANNELS, 2),
        )
        self.hot_enc = hot_enc
        self.use_norm = use_norm
        if self.use_norm:
            self.linear = NormedLinear(self._FEATURE_CHANNELS, 2)

    def forward(self, img):
        """Classify ``img``.

        Returns:
            A tuple ``(probabilities, logits)``:
            - ``use_norm`` truthy: softmax over (N, 2) cosine logits.
            - ``hot_enc == 0``: sigmoid over (N,) logits (unit dim squeezed).
            - ``hot_enc == 1``: softmax over (N, 2) logits.

        Raises:
            ValueError: if ``use_norm`` is falsy and ``hot_enc`` is not 0 or 1.
                The previous version silently returned ``None`` here.
        """
        features = self.processing(img).mean([2, 3, 4])
        # Truthy test, matching the condition under which __init__ builds
        # self.linear.
        if self.use_norm:
            logits = self.linear(features)
            return F.softmax(logits, dim=1), logits
        if self.hot_enc == 0:
            logits = torch.squeeze(self.dnn1(features), dim=1)
            return torch.sigmoid(logits), logits
        if self.hot_enc == 1:
            logits = self.dnn2(features)
            return F.softmax(logits, dim=1), logits
        raise ValueError(
            f"unsupported hot_enc={self.hot_enc!r}; expected 0 or 1"
        )